{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f08a8838",
   "metadata": {
    "papermill": {
     "duration": 5.790567,
     "end_time": "2024-06-27T02:45:29.465864",
     "exception": false,
     "start_time": "2024-06-27T02:45:23.675297",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:24:25.616352300Z",
     "start_time": "2024-10-23T02:24:23.552393600Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\admin\\AppData\\Roaming\\Python\\Python310\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
      "  from pandas.core import (\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "import torch\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms as tfs\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7452a0b",
   "metadata": {
    "papermill": {
     "duration": 1.908979,
     "end_time": "2024-06-27T02:45:31.378366",
     "exception": false,
     "start_time": "2024-06-27T02:45:29.469387",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:30:49.263413600Z",
     "start_time": "2024-10-23T02:30:20.184387200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:997)>\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./mnist\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/9912422 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "be39c1688f9b4ed6aeaeccd03cb82137"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./mnist\\MNIST\\raw\\train-images-idx3-ubyte.gz to ./mnist\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:997)>\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./mnist\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/28881 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d0d1ea5e40c84b6a94a031966eaaa63c"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./mnist\\MNIST\\raw\\train-labels-idx1-ubyte.gz to ./mnist\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:997)>\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./mnist\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/1648877 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "69f62308e89e47b79de6f1b84440f945"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./mnist\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./mnist\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Failed to download (trying next):\n",
      "<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:997)>\n",
      "\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./mnist\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/4542 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5114599ebf60412c927b98b6c3d2ed0c"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./mnist\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to ./mnist\\MNIST\\raw\n"
     ]
    }
   ],
   "source": [
    "im_tfs = tfs.Compose([\n",
    "    tfs.ToTensor(),\n",
    "    tfs.Normalize((0.5, ), (0.5,))\n",
    "#     tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化\n",
    "])\n",
    "train_set = torchvision.datasets.MNIST(\n",
    "    root=\"./mnist\", train=True, download=True, transform=im_tfs\n",
    ")\n",
    "val_set = torchvision.datasets.MNIST(\n",
    "    root=\"./mnist\", train=False, download=True, transform=im_tfs\n",
    ")\n",
    "\n",
    "# train_set = MNIST('/kaggle/working/mnist', transform=im_tfs)\n",
    "train_data = DataLoader(train_set, batch_size=128, shuffle=True)\n",
    "test_data= DataLoader(val_set,batch_size=128,shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3d42662b",
   "metadata": {
    "papermill": {
     "duration": 0.019369,
     "end_time": "2024-06-27T02:45:31.403148",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.383779",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:31:21.695616700Z",
     "start_time": "2024-10-23T02:31:21.690977500Z"
    }
   },
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(VAE, self).__init__()\n",
    "\n",
    "        self.fc1 = nn.Linear(784, 400)\n",
    "        self.fc21 = nn.Linear(400, 20) # mean\n",
    "        self.fc22 = nn.Linear(400, 20) # var\n",
    "        self.fc3 = nn.Linear(20, 400)\n",
    "        self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparametrize(self, mu, logvar):\n",
    "        std = logvar.mul(0.5).exp_()\n",
    "        eps = torch.FloatTensor(std.size()).normal_()\n",
    "        if torch.cuda.is_available():\n",
    "            eps = Variable(eps.cuda())\n",
    "        else:\n",
    "            eps = Variable(eps)\n",
    "        return eps.mul(std).add_(mu)\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return F.tanh(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x) # 编码\n",
    "        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布\n",
    "        return self.decode(z), mu, logvar # 解码，同时输出均值方差"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c892e68d",
   "metadata": {
    "papermill": {
     "duration": 0.225375,
     "end_time": "2024-06-27T02:45:31.633817",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.408442",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:31:25.371814600Z",
     "start_time": "2024-10-23T02:31:25.364741600Z"
    }
   },
   "outputs": [],
   "source": [
    "net = VAE() # 实例化网络\n",
    "if torch.cuda.is_available():\n",
    "    net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "95cbbf69",
   "metadata": {
    "papermill": {
     "duration": 0.561288,
     "end_time": "2024-06-27T02:45:32.200936",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.639648",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:31:27.672713500Z",
     "start_time": "2024-10-23T02:31:27.647755500Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.1687,  0.0574, -0.0285, -0.3483,  0.0950, -0.3467,  0.0298,  0.4671,\n",
      "         -0.0170, -0.1701,  0.4428, -0.0195, -0.0497, -0.0503,  0.6589,  0.0782,\n",
      "         -0.0719, -0.6350,  0.1272, -0.1256]], grad_fn=<AddmmBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\admin\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\nn\\functional.py:1956: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
     ]
    }
   ],
   "source": [
    "x, _ = train_set[0]\n",
    "x = x.view(x.shape[0], -1)\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "x = Variable(x)\n",
    "_, mu, var = net(x)\n",
    "print(mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fc8ab604",
   "metadata": {
    "papermill": {
     "duration": 0.022903,
     "end_time": "2024-06-27T02:45:32.229708",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.206805",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:31:37.682439700Z",
     "start_time": "2024-10-23T02:31:37.676910100Z"
    }
   },
   "outputs": [],
   "source": [
    "reconstruction_function = nn.MSELoss(reduction='sum')\n",
    "\n",
    "def loss_function(recon_x, x, mu, logvar):\n",
    "    \"\"\"\n",
    "    recon_x: generating images\n",
    "    x: origin images\n",
    "    mu: latent mean\n",
    "    logvar: latent log variance\n",
    "    \"\"\"\n",
    "    MSE = reconstruction_function(recon_x, x)\n",
    "    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n",
    "    KLD = torch.sum(KLD_element).mul_(-0.5)\n",
    "    # KL divergence\n",
    "    return MSE + KLD\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "81553358",
   "metadata": {
    "papermill": {
     "duration": 0.015467,
     "end_time": "2024-06-27T02:45:32.251280",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.235813",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:31:40.973746700Z",
     "start_time": "2024-10-23T02:31:40.968203700Z"
    }
   },
   "outputs": [],
   "source": [
    "def to_img(x):\n",
    "    '''\n",
    "    定义一个函数将最后的结果转换回图片\n",
    "    '''\n",
    "    x = 0.5 * (x + 1.)\n",
    "    x = x.clamp(0, 1)\n",
    "    x = x.view(x.shape[0], 1, 28, 28)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "63b68158",
   "metadata": {
    "papermill": {
     "duration": 1477.573568,
     "end_time": "2024-06-27T03:10:09.831022",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.257454",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:33:30.288805800Z",
     "start_time": "2024-10-23T02:31:44.717301800Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, Loss: 85.7441\n",
      "epoch: 2, Loss: 81.6198\n",
      "epoch: 3, Loss: 73.2632\n",
      "epoch: 4, Loss: 68.6567\n",
      "epoch: 5, Loss: 70.9267\n",
      "训练用时105.56150674819946\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "for e in range(5):\n",
    "    for im, _ in train_data:\n",
    "        im = im.view(im.shape[0], -1)\n",
    "        im = Variable(im)\n",
    "        if torch.cuda.is_available():\n",
    "            im = im.cuda()\n",
    "        recon_im, mu, logvar = net(im)\n",
    "        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    if (e + 1) % 1 == 0:\n",
    "        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))\n",
    "        save = to_img(recon_im.cpu().data)\n",
    "        if not os.path.exists('./vae_img'):\n",
    "            os.mkdir('./vae_img')\n",
    "        save_image(save, './vae_img/image_{}.png'.format(e + 1))\n",
    "        \n",
    "        \n",
    "end = time.time()\n",
    "print(f\"训练用时{end-start}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4945a186",
   "metadata": {
    "papermill": {
     "duration": 0.020103,
     "end_time": "2024-06-27T03:10:09.857686",
     "exception": false,
     "start_time": "2024-06-27T03:10:09.837583",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:35:06.503076600Z",
     "start_time": "2024-10-23T02:35:06.357444500Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.3484, -0.2448,  0.0817,  0.0212,  0.6749,  1.0675,  0.0067,  2.0301,\n",
      "         -1.0485, -0.3447,  0.4381,  1.2931,  0.7755, -0.6412,  1.2218, -0.8677,\n",
      "          0.1716,  1.5818,  0.9057,  1.3432]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "x, _ = train_set[0]\n",
    "x = x.view(x.shape[0], -1)\n",
    "if torch.cuda.is_available():\n",
    "    x = x.cuda()\n",
    "x = Variable(x)\n",
    "_, mu, _ = net(x)\n",
    "print(mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "b03e701ada614bb3"
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [],
   "dockerImageVersionId": 30733,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 1491.585132,
   "end_time": "2024-06-27T03:10:11.391316",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2024-06-27T02:45:19.806184",
   "version": "2.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
